# coding: utf-8
# author: Han Zhang


import sys
sys.path.insert(0, "../lib/")
from LSTM_text_dependency import *

from keras.models import load_model
from keras.models import model_from_json
from keras.preprocessing import sequence
from keras.preprocessing import text
from keras_self_attention import SeqWeightedAttention

from scipy import stats
import random
import numpy as np
import json
import codecs
import os


import warnings
warnings.filterwarnings('ignore')
np.seterr(all='raise')


os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from keras import backend as K
K.set_image_dim_ordering('tf')


import tensorflow as tf
global graph

graph = tf.get_default_graph()


dirname = os.path.dirname(__file__)
filename = os.path.join(dirname, '../supporting/vocab_pos_grievance.dict')

with open(filename) as json_file:
	wordpos = json.load(json_file)


model_grievance = load_model("../modelfiles/weights_text-stage2.hdf5", custom_objects = {"SeqWeightedAttention":SeqWeightedAttention}) 
model_grievance._make_predict_function()


def predict_text_deep_protest_vs_grievance(text):
	""" this function takes a (segmented) text as input 
	and output the raw second stage predicted probability of this text discussing grievances vs. protests, from the second-stage classifier

	"""
	seq_grievance = string2sequence (text, wordpos)
	sm_g = np.array([seq_grievance])


	with graph.as_default():

		yg = model_grievance.predict(sm_g)


	return yg[0][0]


if __name__ == '__main__':

	print predict_text_deep_protest_vs_grievance("特警 被 围攻   八名 特警 对峙 二十名 手无寸铁 的 群众   却 打 不过   晕死   怎么 得 一个打 俩 啊   还 特警 来   我 呸         省 道".decode('utf-8'))
	# print predict_text_deep_protest_vs_grievance("山东 聊城 东昌府区 政府 采用 株连 公职人员 的 手段 促使 村民 签订 拆迁 协议 陈庄村 柳园 街道 陈庄村 四十余名 公职人员 因 亲属 没有 签订 拆迁 协议 而 遭 政府 通知 要 开除公职 另有 拆迁办 的 通告 称 不要 相信 律师 谗言 对多要 拆迁 补偿 抱有幻想".decode('utf-8'))
 
